load("@rules_python//python:defs.bzl", "py_library", "py_test")
load("//bazel:python.bzl", "doctest")

# Paths excluded from doctest collection: everything under examples/ and
# tests/, the top-level callbacks/*.py modules, and integrations/wandb.py.
_AIR_DOCTEST_EXCLUDES = glob([
    "examples/**/*",
    "tests/**/*",
    "callbacks/*.py",
]) + ["integrations/wandb.py"]

# Doctest target over the remaining .py files in this package, run with
# Train V2 enabled.
doctest(
    name = "py_doctest[air]",
    env = {"RAY_TRAIN_V2_ENABLED": "1"},
    files = glob(
        ["**/*.py"],
        exclude = _AIR_DOCTEST_EXCLUDES,
    ),
    tags = ["team:ml"],
)

# Library wrapping tests/conftest.py (pytest's shared-fixture file, by naming
# convention). :ml_lib excludes tests/*.py, so this file is packaged here
# separately.
# NOTE(review): no target in this file lists ":conftest" in its deps, and no
# visibility is declared — confirm whether the tests below are expected to
# depend on it.
py_library(
    name = "conftest",
    srcs = ["tests/conftest.py"],
)

# --------------------------------------------------------------------
# Tests from the python/ray/air/tests directory.
# Covers all tests starting with `test_`.
# Please keep these sorted alphabetically.
# --------------------------------------------------------------------

py_test(
    name = "test_air_usage",
    size = "small",
    srcs = ["tests/test_air_usage.py"],
    # NOTE: This tests Train V1 telemetry.
    env = {"RAY_TRAIN_V2_ENABLED": "0"},
    tags = [
        "exclusive",
        "team:ml",
    ],
    deps = [":ml_lib"],
)

py_test(
    name = "test_errors",
    size = "medium",
    srcs = ["tests/test_errors.py"],
    # NOTE: This tests Tune (Train V1) error propagation logic.
    env = {"RAY_TRAIN_V2_ENABLED": "0"},
    tags = [
        "exclusive",
        "team:ml",
    ],
    deps = [":ml_lib"],
)

py_test(
    name = "test_experiment_restore",
    size = "large",
    srcs = [
        "tests/_test_experiment_restore_run.py",
        "tests/test_experiment_restore.py",
    ],
    # NOTE: This tests Tune and Train V1 restoration.
    env = {"RAY_TRAIN_V2_ENABLED": "0"},
    tags = [
        "exclusive",
        "team:ml",
    ],
    deps = [":ml_lib"],
)

py_test(
    name = "test_integration_comet",
    size = "small",
    srcs = ["tests/test_integration_comet.py"],
    # NOTE: This tests the Tune Comet callback.
    env = {"RAY_TRAIN_V2_ENABLED": "0"},
    tags = [
        "exclusive",
        "team:ml",
    ],
    deps = [":ml_lib"],
)

py_test(
    name = "test_integration_mlflow",
    size = "medium",
    srcs = ["tests/test_integration_mlflow.py"],
    # NOTE: This tests the Tune mlflow callback, so, like the Comet and wandb
    # callback tests, it runs with Train V2 disabled.
    env = {"RAY_TRAIN_V2_ENABLED": "0"},
    tags = [
        "exclusive",
        "team:ml",
    ],
    deps = [":ml_lib"],
)

py_test(
    name = "test_integration_wandb",
    size = "small",
    srcs = ["tests/test_integration_wandb.py"],
    # NOTE: This tests the Tune wandb callback.
    env = {"RAY_TRAIN_V2_ENABLED": "0"},
    tags = [
        "exclusive",
        "team:ml",
    ],
    deps = [":ml_lib"],
)

py_test(
    name = "test_keras_callback",
    size = "medium",
    srcs = ["tests/test_keras_callback.py"],
    env = {"RAY_TRAIN_V2_ENABLED": "1"},
    tags = [
        "exclusive",
        "team:ml",
    ],
    deps = [":ml_lib"],
)

py_test(
    name = "test_new_dataset_config",
    size = "large",
    srcs = ["tests/test_new_dataset_config.py"],
    # NOTE: Relevant tests moved to train/v2/tests/test_data_integration.py
    env = {"RAY_TRAIN_V2_ENABLED": "0"},
    tags = [
        "exclusive",
        "team:ml",
    ],
    deps = [":ml_lib"],
)

py_test(
    name = "test_remote_storage_hdfs",
    size = "small",
    srcs = ["tests/test_remote_storage_hdfs.py"],
    env = {"RAY_TRAIN_V2_ENABLED": "1"},
    tags = [
        "exclusive",
        "hdfs",
        "team:ml",
    ],
    deps = [":ml_lib"],
)

py_test(
    name = "test_tracebacks",
    size = "small",
    srcs = ["tests/test_tracebacks.py"],
    env = {"RAY_TRAIN_V2_ENABLED": "1"},
    tags = [
        "exclusive",
        "team:ml",
    ],
    deps = [":ml_lib"],
)

py_test(
    name = "test_utils",
    size = "small",
    srcs = ["tests/test_utils.py"],
    env = {"RAY_TRAIN_V2_ENABLED": "1"},
    tags = [
        "exclusive",
        "team:ml",
    ],
    deps = [":ml_lib"],
)

# --------------------------------------------------------------------
# Tests from the python/ray/air/tests/execution directory.
# Covers all tests starting with `test_`.
# Please keep these sorted alphabetically.
# TODO: Move this to Tune.
# --------------------------------------------------------------------

# One py_test per tests/execution/<name>.py file. Every target shares the same
# env, tags and deps; only the test size differs. Keep the (name, size) pairs
# below sorted alphabetically by name.
[
    py_test(
        name = test_name,
        size = test_size,
        srcs = ["tests/execution/%s.py" % test_name],
        env = {"RAY_TRAIN_V2_ENABLED": "1"},
        tags = [
            "exclusive",
            "team:ml",
        ],
        deps = [":ml_lib"],
    )
    for test_name, test_size in [
        ("test_barrier", "small"),
        ("test_e2e_train_flow", "medium"),
        ("test_e2e_tune_flow", "medium"),
        ("test_event_manager", "medium"),
        ("test_resource_manager_fixed", "small"),
        ("test_resource_manager_placement_group", "medium"),
        ("test_resource_request", "small"),
        ("test_tracked_actor", "small"),
        ("test_tracked_actor_task", "small"),
    ]
]

# This is a dummy test dependency: it globs this package's Python sources so
# that the above tests are re-run whenever any of these files change.
py_library(
    name = "ml_lib",
    srcs = glob(
        ["**/*.py"],
        # Only files directly under tests/ are excluded: a single "*" does not
        # cross directory boundaries, so tests/execution/*.py (and any other
        # nested test files) remain part of this library.
        exclude = ["tests/*.py"],
    ),
    # Visible to this package, the Train packages (and their subpackages),
    # and //release.
    visibility = [
        "//python/ray/air:__pkg__",
        "//python/ray/air:__subpackages__",
        "//python/ray/train:__pkg__",
        "//python/ray/train:__subpackages__",
        "//release:__pkg__",
    ],
)
